import random
import json
from PIL import Image


##Generating ICL input for vizwiz dataset
def vizwiz_icl(all_data, tokenizer, num_shot, *, with_instruction=False):
    """Build and tokenize an in-context-learning prompt for VizWiz.

    ``num_shot + 1`` JSON records are drawn without replacement from
    ``all_data`` with ``random.sample``. The first becomes the query and the
    remaining ``num_shot`` become demonstrations whose answers are appended
    to their prompts. Image paths are prefixed with ``../../``.

    Args:
        all_data: sequence of JSON strings, each with ``image`` and
            ``question`` keys; demonstration records also need ``answer``.
        tokenizer: callable invoked as
            ``tokenizer(text, return_tensors='pt', padding='longest')``.
        num_shot: number of demonstrations to include (must be >= 0).
        with_instruction: if True, prepend a natural-language instruction
            (including the "respond unanswerable" hint) to the prompt.
            Defaults to False (no instruction).

    Returns:
        Whatever ``tokenizer`` returns for the assembled prompt.

    Raises:
        ValueError: if ``num_shot`` is negative, or (from ``random.sample``)
            if ``all_data`` holds fewer than ``num_shot + 1`` records.
    """
    if num_shot < 0:
        raise ValueError(f"num_shot must be non-negative, got {num_shot}")

    prompt = '<img>../../{}</img>{} Answer:'
    instruction = (
        'First carefully understand the given examples. Then use the given '
        'image and answer the question in the same way as the examples. '
        'If the question can not be answered, respond unanswerable. '
    )

    sampled_data = random.sample(all_data, num_shot + 1)
    # strip() tolerates surrounding whitespace (e.g. trailing newlines) on every
    # record, including the query.
    query, *shots = [json.loads(s.strip()) for s in sampled_data]

    # Each demonstration is "<prompt> <answer>"; demonstrations are joined with
    # no separator, and the query prompt follows directly.
    few_shot_prompt = ''.join(
        prompt.format(shot['image'], shot['question']) + f" {shot['answer']}"
        for shot in shots
    )
    final_question = few_shot_prompt + prompt.format(query['image'], query['question'])
    if with_instruction:
        final_question = instruction + final_question

    return tokenizer(final_question, return_tensors='pt', padding='longest')


##Generating ICL input for okvqa dataset
def okvqa_icl(all_data, tokenizer, num_shot):
    """Build and tokenize an in-context-learning prompt for OK-VQA.

    ``num_shot + 1`` JSON records are drawn without replacement from
    ``all_data`` with ``random.sample``. The first becomes the query and the
    remaining ``num_shot`` become demonstrations whose answers are appended
    to their prompts.

    Args:
        all_data: sequence of JSON strings, each with ``image`` and
            ``question`` keys; demonstration records also need ``answer``.
        tokenizer: callable invoked as
            ``tokenizer(text, return_tensors='pt', padding='longest')``.
        num_shot: number of demonstrations to include (must be >= 0).

    Returns:
        Whatever ``tokenizer`` returns for the assembled prompt.

    Raises:
        ValueError: if ``num_shot`` is negative, or (from ``random.sample``)
            if ``all_data`` holds fewer than ``num_shot + 1`` records.
    """
    if num_shot < 0:
        raise ValueError(f"num_shot must be non-negative, got {num_shot}")

    prompt = '<img>{}</img>{} Answer:'

    sampled_data = random.sample(all_data, num_shot + 1)
    # strip() tolerates surrounding whitespace (e.g. trailing newlines) on every
    # record, including the query.
    query, *shots = [json.loads(s.strip()) for s in sampled_data]

    # Each demonstration is "<prompt> <answer>"; demonstrations are joined with
    # no separator, and the query prompt follows directly.
    few_shot_prompt = ''.join(
        prompt.format(shot['image'], shot['question']) + f" {shot['answer']}"
        for shot in shots
    )
    final_question = few_shot_prompt + prompt.format(query['image'], query['question'])
    return tokenizer(final_question, return_tensors='pt', padding='longest')


def format_input(cur_dataset, is_eval=False):
    """Return the per-record prompt formatter for ``cur_dataset``.

    Args:
        cur_dataset: dataset name; one of ``"vizwiz"``, ``"okvqa"``,
            ``"flower"`` or ``"cub"``.
        is_eval: accepted for interface compatibility; not used here.

    Returns:
        The matching ``format_*`` function. Each formatter returns a
        ``(prompt, label, question_id)`` tuple.

    Raises:
        ValueError: if ``cur_dataset`` is not a known dataset name.
    """
    # Names are resolved only on a match, so unrelated formatters are never
    # looked up.
    if cur_dataset == "vizwiz":
        return format_vizwiz
    if cur_dataset == "okvqa":
        return format_okvqa
    if cur_dataset == "flower":
        return format_flower
    if cur_dataset == "cub":
        return format_cub
    # Fail loudly here instead of returning None, which would otherwise only
    # surface later as "'NoneType' object is not callable".
    raise ValueError(
        f"Unknown dataset {cur_dataset!r}; expected one of "
        "'vizwiz', 'okvqa', 'flower', 'cub'"
    )

        
def format_vizwiz(cur_data):
    """Render one VizWiz JSON record as a zero-shot prompt.

    Args:
        cur_data: JSON string with ``image``, ``question``, ``answer`` and
            ``question_id`` keys.

    Returns:
        ``(prompt, answer, question_id)``, where the image path in the prompt
        is prefixed with ``../../``.
    """
    record = json.loads(cur_data)
    image, question = record["image"], record["question"]
    text = f"<img>../../{image}</img>{question} Answer:"
    return text, record["answer"], record["question_id"]

def format_okvqa(cur_data):
    """Render one OK-VQA JSON record as a zero-shot prompt.

    Args:
        cur_data: JSON string with ``image``, ``question``, ``answer`` and
            ``question_id`` keys.

    Returns:
        ``(prompt, answer, question_id)``.
    """
    record = json.loads(cur_data)
    image, question = record["image"], record["question"]
    text = f"<img>{image}</img>{question} Answer:"
    return text, record["answer"], record["question_id"]


def format_flower(cur_data):
    """Build a two-way multiple-choice flower-classification prompt.

    One positive and one negative demonstration are shown, followed by the
    query image. Which label is offered as option "A" is decided by a single
    ``random.randint(0, 1)`` call, so the result depends on the ``random``
    module state.

    Args:
        cur_data: mapping with keys ``pos``, ``neg`` and ``query`` (embedded in
            ``<img>`` tags) and ``pos_label`` / ``neg_label`` (option texts).

    Returns:
        ``(prompt, query_label, -1)``, where ``query_label`` is the letter
        ("A" or "B") of the option showing ``pos_label``.
    """
    pos_img = cur_data["pos"]
    neg_img = cur_data["neg"]
    pos_name = cur_data["pos_label"]
    neg_name = cur_data["neg_label"]
    query_img = cur_data["query"]

    if random.randint(0, 1) == 0:
        option_a, option_b, query_label = pos_name, neg_name, "A"
        img_a, img_b = pos_img, neg_img
    else:
        option_a, option_b, query_label = neg_name, pos_name, "B"
        img_a, img_b = neg_img, pos_img

    question = (
        f"What is the type of flower in the image? A.{option_a} B.{option_b}\n"
        "Answer with the option's letter from the given choice directly. Answer:"
    )
    # Demonstrations are always listed in option order: the "A" image first,
    # then the "B" image, then the unanswered query.
    prompt = (
        f"<img>{img_a}</img>{question} A\n"
        f"<img>{img_b}</img>{question} B\n"
        f"<img>{query_img}</img>{question}"
    )
    return prompt, query_label, -1
    

def format_cub(cur_data):
    """Build a two-way multiple-choice bird-classification prompt.

    One positive and one negative demonstration are shown, followed by the
    query image. Which label is offered as option "A" is decided by a single
    ``random.randint(0, 1)`` call, so the result depends on the ``random``
    module state.

    Args:
        cur_data: mapping with keys ``pos``, ``neg`` and ``query`` (embedded in
            ``<img>`` tags) and ``pos_label`` / ``neg_label`` (option texts).

    Returns:
        ``(prompt, query_label, -1)``, where ``query_label`` is the letter
        ("A" or "B") of the option showing ``pos_label``.
    """
    pos_img = cur_data["pos"]
    neg_img = cur_data["neg"]
    pos_name = cur_data["pos_label"]
    neg_name = cur_data["neg_label"]
    query_img = cur_data["query"]

    if random.randint(0, 1) == 0:
        option_a, option_b, query_label = pos_name, neg_name, "A"
        img_a, img_b = pos_img, neg_img
    else:
        option_a, option_b, query_label = neg_name, pos_name, "B"
        img_a, img_b = neg_img, pos_img

    question = (
        f"What is the type of bird in the image? A.{option_a} B.{option_b}\n"
        "Answer with the option's letter from the given choice directly. Answer:"
    )
    # Demonstrations are always listed in option order: the "A" image first,
    # then the "B" image, then the unanswered query.
    prompt = (
        f"<img>{img_a}</img>{question} A\n"
        f"<img>{img_b}</img>{question} B\n"
        f"<img>{query_img}</img>{question}"
    )
    return prompt, query_label, -1
    